# -*- coding: utf-8 -*- #

# -----------------------------------------------------------------------
# File Name:    model.py
# Version:      ver1_0
# Created:      2024/06/17
# Description:  本文件定义了CustomNet类，用于定义神经网络模型
#               ★★★请在空白处填写适当的语句，将CustomNet类的定义补充完整★★★
# -----------------------------------------------------------------------
import torch
from torch import nn


class CustomNet(nn.Module):
    """自定义神经网络模型。
    请完成对__init__、forward方法的实现，以完成CustomNet类的定义。
    """

    def __init__(self):
        """初始化方法。
        在本方法中，请完成神经网络的各个模块/层的定义。
        请确保每层的输出维度与下一层的输入维度匹配。
        """
        super(CustomNet, self).__init__()

        # START----------------------------------------------------------
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 输入通道3，输出通道16，卷积核大小3x3
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)       # 2x2池化层
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) # 输入通道16，输出通道32，卷积核大小3x3
        self.fc1 = nn.Linear(32 * 16 * 16, 128)                            # 全连接层，输入维度32*16*16，输出维度128
        self.fc2 = nn.Linear(128, 10)                                      # 全连接层，输入维度128，输出维度10（假设有10个类别）
        self.relu = nn.ReLU()                                              # ReLU激活函数
        # END------------------------------------------------------------

    def forward(self, x):
        """前向传播过程。
        在本方法中，请完成对神经网络前向传播计算的定义。
        """
        # START----------------------------------------------------------
        x = self.relu(self.conv1(x))  # 第一层卷积 + 激活函数
        x = self.pool(x)              # 第一层池化
        x = self.relu(self.conv2(x))  # 第二层卷积 + 激活函数
        x = self.pool(x)              # 第二层池化
        x = x.view(-1, 32 * 16 * 16)  # 展平操作，将张量重塑为2D张量，batch_size x (32*16*16)
        x = self.relu(self.fc1(x))    # 第一层全连接层 + 激活函数
        x = self.fc2(x)               # 第二层全连接层
        # END------------------------------------------------------------
        return x


if __name__ == "__main__":
    # 测试
    from dataset import CustomDataset
    from torchvision.transforms import ToTensor

    c = CustomDataset('./images/train.txt', './images/train', ToTensor)
    net = CustomNet()                                # 实例化
    x = torch.unsqueeze(c[10]['image'], 0)           # 模拟一个模型的输入数据
    print(net.forward(x))                            # 测试forward方法

